{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "torch.set_printoptions(edgeitems=2)\n",
    "\n",
    "\n",
    "t_c = torch.tensor([0.5, 14.0, 15.0, 28.0, 11.0, 8.0,\n",
    "                    3.0, -4.0, 6.0, 13.0, 21.0])\n",
    "t_u = torch.tensor([35.7, 55.9, 58.2, 81.9, 56.3, 48.9,\n",
    "                    33.9, 21.8, 48.4, 60.4, 68.4])\n",
    "t_un = 0.1 * t_u"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(t_u,w,b):\n",
    "    return w * t_u + b\n",
    "def loss_fn(t_p,t_c):\n",
    "    squared_diffs = (t_p - t_c)**2\n",
    "    return squared_diffs.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = torch.tensor([1.0,0.0],requires_grad = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = loss_fn(model(t_u,*params),t_c)\n",
    "loss.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([4517.2969,   82.6000])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "params.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 注意，每次更新完参数都应把梯度置0\n",
    "if params.grad is not None:\n",
    "    params.grad.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def training_loop(n_epochs,learning_rate,params,t_u,t_c):\n",
    "    for epoch in range(1,n_epochs+1):\n",
    "        if params.grad is not None:\n",
    "            params.grad.zero_()\n",
    "        t_p = model(t_u,*params)\n",
    "        loss = loss_fn(t_p,t_c)\n",
    "        loss.backward()\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            params -= learning_rate*params.grad\n",
    "            \n",
    "        if epoch %500 ==0:\n",
    "            print('Epoch %d,Loss %f' %(epoch,float(loss)))\n",
    "            \n",
    "    return params\n",
    "            \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 500,Loss 7.860116\n",
      "Epoch 1000,Loss 3.828538\n",
      "Epoch 1500,Loss 3.092191\n",
      "Epoch 2000,Loss 2.957697\n",
      "Epoch 2500,Loss 2.933134\n",
      "Epoch 3000,Loss 2.928648\n",
      "Epoch 3500,Loss 2.927830\n",
      "Epoch 4000,Loss 2.927679\n",
      "Epoch 4500,Loss 2.927652\n",
      "Epoch 5000,Loss 2.927647\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([  5.3671, -17.3012], requires_grad=True)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_loop(n_epochs=5000,\n",
    "             learning_rate=1e-2,\n",
    "             params=torch.tensor([1.0,0.0],requires_grad = True),\n",
    "             t_u = t_un,t_c = t_c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
